{
 "cells": [
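  {
   "cell_type": "markdown",
   "id": "90c6730f-5d76-450b-9788-ec883d024f57",
   "metadata": {},
   "source": [
    "# Getting Started with Fine-Tuning Using Hugging Face Transformers\n",
    "\n",
    "This example walks through the main steps of fine-tuning a model with Transformers:\n",
    "- Downloading the dataset\n",
    "- Preprocessing the data\n",
    "- Configuring training hyperparameters\n",
    "- Setting up evaluation metrics\n",
    "- A basic introduction to the Trainer\n",
    "- Hands-on training\n",
    "- Saving the model"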
   ]
  },
  {
   "cell_type": "markdown",
   "id": "aa0b1e12-1921-4438-8d5d-9760a629dcfe",
   "metadata": {},
   "source": [
    "## The YelpReviewFull dataset\n",
    "\n",
    "**Hugging Face dataset: [YelpReviewFull](https://huggingface.co/datasets/yelp_review_full)**\n",
    "\n",
    "### Dataset summary\n",
    "\n",
    "The Yelp reviews dataset consists of reviews from Yelp. It is extracted from the Yelp Dataset Challenge 2015 data.\n",
    "\n",
    "### Supported tasks and leaderboards\n",
    "Text classification and sentiment classification: the dataset is mainly used for text classification, that is, predicting the sentiment of a given text.\n",
    "\n",
    "### Languages\n",
    "The reviews are mainly written in English.\n",
    "\n",
    "### Dataset structure\n",
    "\n",
    "#### Data instances\n",
    "A typical data point consists of a text and its corresponding label.\n",
    "\n",
    "An example from the YelpReviewFull test set looks like this:\n",
    "\n",
    "```json\n",
    "{\n",
    "    'label': 0,\n",
    "    'text': 'I got \\'new\\' tires from them and within two weeks got a flat. I took my car to a local mechanic to see if i could get the hole patched, but they said the reason I had a flat was because the previous patch had blown - WAIT, WHAT? I just got the tire and never needed to have it patched? This was supposed to be a new tire. \\\\nI took the tire over to Flynn\\'s and they told me that someone punctured my tire, then tried to patch it. So there are resentful tire slashers? I find that very unlikely. After arguing with the guy and telling him that his logic was far fetched he said he\\'d give me a new tire \\\\\"this time\\\\\". \\\\nI will never go back to Flynn\\'s b/c of the way this guy treated me and the simple fact that they gave me a used tire!'\n",
    "}\n",
    "```\n",
    "\n",
    "#### Data fields\n",
    "\n",
    "- 'text': The review text is escaped using double quotes (\"), and any internal double quote is escaped by 2 double quotes (\"\"). New lines are escaped by a backslash followed by an \"n\" character, that is \"\\n\".\n",
    "- 'label': Corresponds to the score of the review (between 1 and 5). In the loaded dataset it is stored as a class index from 0 to 4, where 0 means 1 star and 4 means 5 stars.\n",
    "\n",
    "#### Data splits\n",
    "\n",
    "The Yelp reviews full star dataset is constructed by randomly taking 130,000 training samples and 10,000 testing samples for each review star from 1 to 5. In total there are 650,000 training samples and 50,000 testing samples.\n",
    "\n",
    "## Downloading the dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "bbf72d6c-7ea5-4ee1-969a-c5060b9cb2d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "dataset = load_dataset(\"yelp_review_full\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "ec6fc806-1395-42dd-8121-a6e98a95cf01",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['label', 'text'],\n",
       "        num_rows: 650000\n",
       "    })\n",
       "    test: Dataset({\n",
       "        features: ['label', 'text'],\n",
       "        num_rows: 50000\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "c94ad529-1604-48bd-8c8d-aa2f3bca6200",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'label': 1,\n",
       " 'text': 'This Starbucks is teeny-tiny!\\\\n\\\\nSeating inside is VERY limited.  This is a Starbucks to grab and go and continue your shopping at the Waterfront.\\\\n\\\\nBaristas are friendly and fast.'}"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dataset[\"train\"][110]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "6dc45997-e391-456f-b0b9-d3193b0f6a9d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "import pandas as pd\n",
    "import datasets\n",
    "from IPython.display import display, HTML"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "9e2ecebb-d5d1-456d-967c-842a79fdd622",
   "metadata": {},
   "outputs": [],
   "source": [
    "def show_random_elements(dataset, num_examples=10):\n",
    "    # Display num_examples distinct, randomly chosen rows of the dataset as an HTML table\n",
    "    assert num_examples <= len(dataset), \"Can't pick more elements than there are in the dataset.\"\n",
    "    picks = []\n",
    "    for _ in range(num_examples):\n",
    "        pick = random.randint(0, len(dataset)-1)\n",
    "        # Redraw until the index has not been picked before\n",
    "        while pick in picks:\n",
    "            pick = random.randint(0, len(dataset)-1)\n",
    "        picks.append(pick)\n",
    "    \n",
    "    df = pd.DataFrame(dataset[picks])\n",
    "    # Replace integer class indices with their label names (e.g. 0 -> '1 star')\n",
    "    for column, typ in dataset.features.items():\n",
    "        if isinstance(typ, datasets.ClassLabel):\n",
    "            df[column] = df[column].transform(lambda i: typ.names[i])\n",
    "    display(HTML(df.to_html()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "1af560b6-7d21-499e-9b82-114be371a98a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>label</th>\n",
       "      <th>text</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1 star</td>\n",
       "      <td>Decided to get some pizza and garlic bites to go. I got the 12 pieces on the bites and wanted 2 sauces. The server assured me that there was 2 marinara sauces. There was only one. I asked if they had a garlic sauce and she said no, they have garlic sauce on them already, mine were quite dry. To top it off, the one sauce they gave me had a loose lid and all the marina sauce ended up everywhere except on the bites. The chain pizza places get a lot of criticism, but they are at least consistent. Would you pay $7 for this?</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>5 stars</td>\n",
       "      <td>This place is awesome!  Make sure you go during happy hour; you won't be disappointed!  My wife and I took our children there at 630 on Saturday night and were surprised when we got our check and was half off!   This is the first restaurant we've been to in a long time where I was content with the amount we spent.</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>4 stars</td>\n",
       "      <td>Great coffee, good service, owners are nice. Food is quite good just not 5 stars.</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>1 star</td>\n",
       "      <td>The Pharmacy crew that is usually there, a stocky man with very close shaven hair, and a slender woman with long brown hair with bangs are absolutely horrible.  I have to use this pharmacy because I have many prescriptions due on different days of the month and this pharmacy is the only convenient one that takes my insurance.\\n\\nI don't have a problem with how long it takes for them to fill my prescriptions, I have a problem with their attitude.  They're rarely friendly, most of the time short toned and distracted, even at times snappish if they can't read the prescription etc.  \\n\\nI have never had a problem like this anywhere else, and I don't like writing bad reviews, but the behavior of these two is really bad, has been ongoing and I've talked with others who have had the same problems with them. \\n\\nLike I said, I really need to go to this pharmacy, it would be difficult to change.  I have multiple health issues.  It's difficult for me to stay positive, but I do it.  I don't understand how people in the health industry who must understand how sick some of the people they are paid to help are, because we don't all look as ill as we are, can't at least but a smile on their face and be positive when they are giving us our prescriptions.  The other customers I've spoken to about these two pharmacy assistants don't have serious health concerns like I do, they still don't like being treated in such a rude manner.  Please do something about this Safeway.\\n\\nThe lead Pharmacist who works with these two is great.  He used to work at the Safeway on Thunderbird and 7th Street, I was really happy to see that he was working here, unfortunately most of my interactions aren't with him.  He always has a smile on his face, and when I do interact with him he has an awesome attitude.</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>5 stars</td>\n",
       "      <td>I am lucky enough to have health insurance but for those that do not they need to come and check out Community Medical Clinic. For starters when I called to make my new patient appointment I was offered a same day appointment. I almost fell over at this being offered to me being that typically you can wait a month to see a doctor at some larger medical centers. \\n\\nUpon arrival I was given the typical paperwork for new patients but noticed that the level of questions was  more personal in a good way. I could tell that these people care about my care and health. Dr. Moody came in and actually had a conversation with me about my health. He was very personable and kind. Dustin who was at the reception desk was also very kind. He took my blood, blood pressure, and asked me some questions.\\n\\nThe service was quick, personable, and genuine. Three things you don't get much in healthcare these days. If you need a primary care doctor in the Downtown Phoenix area I highly recommend going here!</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>4 stars</td>\n",
       "      <td>I just had a gel polish removal and polish change to regular,  and it was my first visit.\\nI loved the polish organization system and how clean and organized everything was. I didn't see nail dust,  nail trimmings, or dirt anywhere! !\\nMy only complaint?  The top coat they used, it reacted very poorly to the sun. It darkened my light mint polish to a muddy color, and chipped off all my tips, showing the true color underneath. \\nIf it's the normal polish, I will have to bring my own! \\nAlso, I would love if they gave me ideas for nail art to try.</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>1 star</td>\n",
       "      <td>This review is for the two employees who happen to be working at the time, not about Geico as a whole. I have never felt so unwelcome in a business as these two associates made me feel. They were so blatantly rude and acted like we were wasting their time. I went in with a couple questions and they acted like we didn't even deserve the time of day. I have never been in a geico office thanks for the WORST FIRST IMPRESSION ever. Doesn't even deserve one star but I had to pick one</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>4 stars</td>\n",
       "      <td>Pretty darn good burger made to order and they have a cozy bar apart from the regular seating that has a high brow atmosphere but plenty of customers just relaxing and enjoying the food with family and friends.  This is a great place for dinner with a big group.  Oh and they're open on Sundays which is a huge plus in my book.\\n\\nUpdate: As of February '10 they have updated the rest of the seating area so it now matches the bar.  In the anteroom they've added billiards and there's darts in two different corners.\\n\\nIt's like a Dilworth Neighborhood Grille in Myers park, only the kitchen still has a lot of kinks to work out.  Went for breakfast on Saturday and even though there were only two tables in the entire place, the waitresses had to tell us about all the stuff they were out of.  If they can get their kitchen sorted, this could pick up.  It would be a shame if they gave up because it finally has a contemporary, if sporty, look and feel and you could have a party with 80 people in the dining room alone.</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>5 stars</td>\n",
       "      <td>Love this place. Anybody who throws up any of La's food just isn't accustomed to ethnic spice.</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>4 stars</td>\n",
       "      <td>The roti canai and char kway teow were great. The Hokkien char mee was OK but if I ordered it again, I'd ask for more spice. We also got a fresh coconut to slurp and munch on.\\n\\nThe staff was friendly and it was easy to get refills and order additional dishes (mmm satay) without feeling compelled to rush through dinner. We were dining a bit late on a Friday and there was no wait.\\n\\nWhen dinner was over the check came and it was such a cheap meal for a couple gluttons. We were transported to Malaysia, for the length of our meal, for much less than a plane ticket.</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_random_elements(dataset[\"train\"])"
   ]
  },
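  {
   "cell_type": "markdown",
   "id": "c9df7cd0-23cd-458f-b2b5-f025c3b9fe62",
   "metadata": {},
   "source": [
    "## Preprocessing the data\n",
    "\n",
    "Once the dataset has been downloaded locally, use a tokenizer to process the text. Because the inputs vary in length, padding and truncation strategies are used to bring them to a uniform length.\n",
    "\n",
    "The Datasets `map` method applies a preprocessing function to the entire dataset in a single call.\n",
    "\n",
    "Below, the whole dataset is processed with the pad-to-maximum-length strategy:"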
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "8bf2b342-e1dd-4ab6-ad57-28eb2513ae38",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"bert-base-cased\")\n",
    "\n",
    "\n",
    "def tokenize_function(examples):\n",
    "    return tokenizer(examples[\"text\"], padding=\"max_length\", truncation=True)\n",
    "\n",
    "\n",
    "tokenized_datasets = dataset.map(tokenize_function, batched=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "47a415a8-cd15-4a8c-851b-9b4740ef8271",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>label</th>\n",
       "      <th>text</th>\n",
       "      <th>input_ids</th>\n",
       "      <th>token_type_ids</th>\n",
       "      <th>attention_mask</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>2 star</td>\n",
       "      <td>I took takeout so my review will be somewhat limited. \\n\\nFirst impression ... you come in and the place looks tired. Not bad, but kind of like \\\"I wish you would go away so I can rest\\\" tired. \\nWouldn't want to actually eat there, but for the 10 minutes it took it was ok. Didn't try Mongolian but it looked ok.\\n\\nCan't beat the takeout. $6.99 (takeout or reinstated price?) for a 3 section styrofoam box that you can fill to your little heart's content. As long as the lid can close comfortably ... there is a sign to this effect.\\n\\nThe buffet was average. I stuck to the appetizer-type stuff and filled it up. Ultimately it was enough for 2 or 3 meals. Not bad for a set price. However, if you're craving quality Chinese, this is not it. Not by a long shot. Good for a day when you're not picky about taste (not bad, mind you) or content.\\n\\nTakeout also came with a container for soup. The hot and sour did look awesome, with large chunks of tofu, and it was one of the better ones I have tried. Upon request, the attendant also filled a large bag of crispy noodles for me to take along. They did have little cups for sauces, but the hot mustard looked plastic and congealed. Scary to eat, scary to taste so I left it there. And there was no plum sauce to mix it with.  \\n\\nI was surprised that they didn't have a regular menu also, but all buffet, all the time ... \\n\\nwill I go again ? one of life's questions, yet to be answered. And asked ...</td>\n",
       "      <td>[101, 146, 1261, 1321, 3554, 1177, 1139, 3189, 1209, 1129, 4742, 2609, 119, 165, 183, 165, 183, 2271, 11836, 1204, 8351, 119, 119, 119, 1128, 1435, 1107, 1105, 1103, 1282, 2736, 4871, 119, 1753, 2213, 117, 1133, 1912, 1104, 1176, 165, 107, 146, 3683, 1128, 1156, 1301, 1283, 1177, 146, 1169, 1832, 165, 107, 4871, 119, 165, 183, 2924, 6094, 5253, 1179, 112, 189, 1328, 1106, 2140, 3940, 1175, 117, 1133, 1111, 1103, 1275, 1904, 1122, 1261, 1122, 1108, 21534, 119, 10265, 112, 189, 2222, 19210, 1133, 1122, 1350, 21534, 119, 165, 183, 165, 183, 1658, 1389, 112, 189, 3222, ...]</td>\n",
       "      <td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...]</td>\n",
       "      <td>[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>2 star</td>\n",
       "      <td>The best description I can give about this restaurant is that it is the Panda Express of asian food and not worth the price.\\n\\nFor appetizer we had the crispy asparagus. It was just japanese crackers deepfried around asparagus. Too weird of a texture but okay to eat. Only 6 sticks. So to me, 6 asparagus heads is not worth that price.\\n\\nI got the lobster curry. Again, for the price, it was not worth. i can get thai green curry for $5.95 down the street and throw in some steamed lobster and get the same effect. \\n\\nhubby got korubuta pork chop. okay, but have had much better elsewhere.\\n\\nthe place is also really dark, to the point where you cannot read your menu. you have to put the menu an inch away from the candle on your table. lame.\\n\\nand the TI sirens show in the background is VERY annoying.\\n\\nskip this place. and TI, please don't make lame asian restaurants. it's an insult to the main demographic of your casino.</td>\n",
       "      <td>[101, 1109, 1436, 6136, 146, 1169, 1660, 1164, 1142, 4382, 1110, 1115, 1122, 1110, 1103, 6991, 1810, 5764, 1104, 1112, 1811, 2094, 1105, 1136, 3869, 1103, 3945, 119, 165, 183, 165, 183, 2271, 1766, 12647, 26883, 6198, 1195, 1125, 1103, 19501, 1183, 1112, 17482, 28026, 1116, 119, 1135, 1108, 1198, 179, 26519, 13309, 8672, 1468, 1996, 13762, 1213, 1112, 17482, 28026, 1116, 119, 6466, 6994, 1104, 170, 16117, 1133, 3008, 1106, 3940, 119, 2809, 127, 13963, 119, 1573, 1106, 1143, 117, 127, 1112, 17482, 28026, 1116, 4075, 1110, 1136, 3869, 1115, 3945, 119, 165, 183, 165, 183, 2240, 1400, 1103, ...]</td>\n",
       "      <td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...]</td>\n",
       "      <td>[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...]</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_random_elements(tokenized_datasets[\"train\"], num_examples=2)"
   ]
  },
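  {
   "cell_type": "markdown",
   "id": "1c33d153-f729-4f04-972c-a764c1cbbb8b",
   "metadata": {},
   "source": [
    "### Sampling the data\n",
    "\n",
    "A subset of 1,000 samples is enough to demonstrate small-scale training on BERT (using the PyTorch Trainer). In the cell below, the `.select(range(1000))` calls are commented out, so the full shuffled splits are used; uncomment them to train on the 1,000-sample subset.\n",
    "\n",
    "The `shuffle()` function randomly reorders the rows of the dataset. For more control over the shuffling algorithm, pass the `generator` argument to use a different `numpy.random.Generator`."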
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "a17317d8-3c6a-467f-843d-87491f600db1",
   "metadata": {},
   "outputs": [],
   "source": [
    "small_train_dataset = tokenized_datasets[\"train\"].shuffle(seed=42) # .select(range(1000))\n",
    "small_eval_dataset = tokenized_datasets[\"test\"].shuffle(seed=42) # .select(range(1000))"
   ]
  },
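  {
   "cell_type": "markdown",
   "id": "d3b65d63-2d3a-4a56-bc31-6e88a29e9dec",
   "metadata": {},
   "source": [
    "## Fine-tuning configuration\n",
    "\n",
    "### Loading the BERT model\n",
    "\n",
    "Loading the model prints a warning that some weights (`classifier.weight` and `classifier.bias`) were not initialized from the checkpoint and are newly initialized. This is entirely normal when fine-tuning: the head used for the pretraining task is dropped and replaced with a new classification head for which there are no pretrained weights. The library therefore warns that the model should be fine-tuned before it is used for inference, which is exactly what we are about to do."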
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "4d2af4df-abd4-4a4b-94b6-b0e7375304ed",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "78236c18c9d04f059beed3ffc0e58e2e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "model.safetensors:   0%|          | 0.00/436M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    }
   ],
   "source": [
    "from transformers import AutoModelForSequenceClassification\n",
    "\n",
    "model = AutoModelForSequenceClassification.from_pretrained(\"bert-base-cased\", num_labels=5)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b44014df-b52c-4c72-9e9f-54424725a473",
   "metadata": {},
   "source": [
    "### Training hyperparameters (TrainingArguments)\n",
    "\n",
    "Full list of parameters and their default values: https://huggingface.co/docs/transformers/v4.36.1/en/main_classes/trainer#transformers.TrainingArguments\n",
    "\n",
    "Source definition: https://github.com/huggingface/transformers/blob/v4.36.1/src/transformers/training_args.py#L161\n",
    "\n",
    "**The most important setting: the path where model weights are saved (`output_dir`)**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "98c01d5c-de72-4ff0-b11d-e07ac5346888",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import TrainingArguments\n",
    "\n",
    "model_dir = \"models/bert-base-cased-finetune-yelp\"\n",
    "\n",
    "# logging_steps defaults to 500; set it to 100 to suit our training data and step count\n",
    "training_args = TrainingArguments(output_dir=model_dir,\n",
    "                                  per_device_train_batch_size=16,\n",
    "                                  num_train_epochs=5,\n",
    "                                  logging_steps=100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "0ce03480-3aaa-48ea-a0c6-a177b8d8e34f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "TrainingArguments(\n",
      "_n_gpu=1,\n",
      "adafactor=False,\n",
      "adam_beta1=0.9,\n",
      "adam_beta2=0.999,\n",
      "adam_epsilon=1e-08,\n",
      "auto_find_batch_size=False,\n",
      "bf16=False,\n",
      "bf16_full_eval=False,\n",
      "data_seed=None,\n",
      "dataloader_drop_last=False,\n",
      "dataloader_num_workers=0,\n",
      "dataloader_persistent_workers=False,\n",
      "dataloader_pin_memory=True,\n",
      "ddp_backend=None,\n",
      "ddp_broadcast_buffers=None,\n",
      "ddp_bucket_cap_mb=None,\n",
      "ddp_find_unused_parameters=None,\n",
      "ddp_timeout=1800,\n",
      "debug=[],\n",
      "deepspeed=None,\n",
      "disable_tqdm=False,\n",
      "dispatch_batches=None,\n",
      "do_eval=False,\n",
      "do_predict=False,\n",
      "do_train=False,\n",
      "eval_accumulation_steps=None,\n",
      "eval_delay=0,\n",
      "eval_steps=None,\n",
      "evaluation_strategy=IntervalStrategy.NO,\n",
      "fp16=False,\n",
      "fp16_backend=auto,\n",
      "fp16_full_eval=False,\n",
      "fp16_opt_level=O1,\n",
      "fsdp=[],\n",
      "fsdp_config={'min_num_params': 0, 'xla': False, 'xla_fsdp_grad_ckpt': False},\n",
      "fsdp_min_num_params=0,\n",
      "fsdp_transformer_layer_cls_to_wrap=None,\n",
      "full_determinism=False,\n",
      "gradient_accumulation_steps=1,\n",
      "gradient_checkpointing=False,\n",
      "gradient_checkpointing_kwargs=None,\n",
      "greater_is_better=None,\n",
      "group_by_length=False,\n",
      "half_precision_backend=auto,\n",
      "hub_always_push=False,\n",
      "hub_model_id=None,\n",
      "hub_private_repo=False,\n",
      "hub_strategy=HubStrategy.EVERY_SAVE,\n",
      "hub_token=<HUB_TOKEN>,\n",
      "ignore_data_skip=False,\n",
      "include_inputs_for_metrics=False,\n",
      "include_num_input_tokens_seen=False,\n",
      "include_tokens_per_second=False,\n",
      "jit_mode_eval=False,\n",
      "label_names=None,\n",
      "label_smoothing_factor=0.0,\n",
      "learning_rate=5e-05,\n",
      "length_column_name=length,\n",
      "load_best_model_at_end=False,\n",
      "local_rank=0,\n",
      "log_level=passive,\n",
      "log_level_replica=warning,\n",
      "log_on_each_node=True,\n",
      "logging_dir=models/bert-base-cased-finetune-yelp/runs/Mar26_18-42-28_WIN-FODVQ3DKG8V,\n",
      "logging_first_step=False,\n",
      "logging_nan_inf_filter=True,\n",
      "logging_steps=100,\n",
      "logging_strategy=IntervalStrategy.STEPS,\n",
      "lr_scheduler_kwargs={},\n",
      "lr_scheduler_type=SchedulerType.LINEAR,\n",
      "max_grad_norm=1.0,\n",
      "max_steps=-1,\n",
      "metric_for_best_model=None,\n",
      "mp_parameters=,\n",
      "neftune_noise_alpha=None,\n",
      "no_cuda=False,\n",
      "num_train_epochs=5,\n",
      "optim=OptimizerNames.ADAMW_TORCH,\n",
      "optim_args=None,\n",
      "output_dir=models/bert-base-cased-finetune-yelp,\n",
      "overwrite_output_dir=False,\n",
      "past_index=-1,\n",
      "per_device_eval_batch_size=8,\n",
      "per_device_train_batch_size=16,\n",
      "prediction_loss_only=False,\n",
      "push_to_hub=False,\n",
      "push_to_hub_model_id=None,\n",
      "push_to_hub_organization=None,\n",
      "push_to_hub_token=<PUSH_TO_HUB_TOKEN>,\n",
      "ray_scope=last,\n",
      "remove_unused_columns=True,\n",
      "report_to=[],\n",
      "resume_from_checkpoint=None,\n",
      "run_name=models/bert-base-cased-finetune-yelp,\n",
      "save_on_each_node=False,\n",
      "save_only_model=False,\n",
      "save_safetensors=True,\n",
      "save_steps=500,\n",
      "save_strategy=IntervalStrategy.STEPS,\n",
      "save_total_limit=None,\n",
      "seed=42,\n",
      "skip_memory_metrics=True,\n",
      "split_batches=False,\n",
      "tf32=None,\n",
      "torch_compile=False,\n",
      "torch_compile_backend=None,\n",
      "torch_compile_mode=None,\n",
      "torchdynamo=None,\n",
      "tpu_metrics_debug=False,\n",
      "tpu_num_cores=None,\n",
      "use_cpu=False,\n",
      "use_ipex=False,\n",
      "use_legacy_prediction_loop=False,\n",
      "use_mps_device=False,\n",
      "warmup_ratio=0.0,\n",
      "warmup_steps=0,\n",
      "weight_decay=0.0,\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "# Full hyperparameter configuration\n",
    "print(training_args)"
   ]
  },
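  {
   "cell_type": "markdown",
   "id": "7ebd3365-d359-4ab4-a300-4717590cc240",
   "metadata": {},
   "source": [
    "### Evaluating metrics during training (Evaluate)\n",
    "\n",
    "The **[Hugging Face Evaluate library](https://huggingface.co/docs/evaluate/index)** gives access, in a single line of code, to dozens of evaluation methods across different domains (natural language processing, computer vision, reinforcement learning, and more). **Full list of currently supported metrics: https://huggingface.co/evaluate-metric**\n",
    "\n",
    "The Trainer does not automatically evaluate model performance during training, so we need to pass it a function that computes and reports metrics.\n",
    "\n",
    "The Evaluate library provides a simple accuracy metric, which you can load with the `evaluate.load` function."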
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "2a8ef138-5bf2-41e5-8c68-df8e11f4e98f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e529cbf720f040e49aed5d9a7f28b4d9",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading builder script:   0%|          | 0.00/4.20k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import numpy as np\n",
    "import evaluate\n",
    "\n",
    "metric = evaluate.load(\"accuracy\")"
   ]
  },
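  {
   "cell_type": "markdown",
   "id": "70d406c0-56d0-4a54-9c6c-e126ab7f5254",
   "metadata": {},
   "source": [
    "\n",
    "Next, call the `compute` function to calculate the accuracy of the predictions.\n",
    "\n",
    "Before passing the predictions to `compute`, we need to convert the logits into predicted class indices by taking the argmax over the last axis (**all Transformers models return logits**)."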
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "f46d2e59-1ebf-43d2-bc86-6b57a4d24d19",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_metrics(eval_pred):\n",
    "    logits, labels = eval_pred\n",
    "    predictions = np.argmax(logits, axis=-1)\n",
    "    return metric.compute(predictions=predictions, references=labels)"
   ]
  },
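  {
   "cell_type": "markdown",
   "id": "e2feba67-9ca9-4793-9a15-3eaa426df2a1",
   "metadata": {},
   "source": [
    "#### Monitoring metrics during training\n",
    "\n",
    "To track how the evaluation metrics change during training, set the `evaluation_strategy` parameter in `TrainingArguments` so that the metrics are reported at the end of each epoch."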
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "afaaee18-4986-4e39-8ad9-b8d413ab4cd1",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import TrainingArguments, Trainer\n",
    "\n",
    "training_args = TrainingArguments(output_dir=model_dir,\n",
    "                                  evaluation_strategy=\"epoch\", \n",
    "                                  per_device_train_batch_size=16,\n",
    "                                  num_train_epochs=3,\n",
    "                                  logging_steps=30)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d47d6981-e444-4c0f-a7cb-dd7f2ba8df12",
   "metadata": {},
   "source": [
    "## Start training\n",
    "\n",
    "### Instantiating the Trainer\n",
    "\n",
    "Note on the `kernel version` warning: for now it does not affect running this example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "ca1d12ac-89dc-4c30-8282-f859724c0062",
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    args=training_args,\n",
    "    train_dataset=small_train_dataset,\n",
    "    eval_dataset=small_eval_dataset,\n",
    "    compute_metrics=compute_metrics,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a833e0db-1168-4a3c-8b75-bfdcef8c5157",
   "metadata": {},
   "source": [
    "## Monitoring GPU usage with nvidia-smi\n",
    "\n",
    "To watch GPU usage in real time, poll with the `watch` command: `watch -n 1 nvidia-smi`:\n",
    "\n",
    "```shell\n",
    "Tue Mar 26 22:21:48 2024\r\n",
    "+-----------------------------------------------------------------------------------------+\r\n",
    "| NVIDIA-SMI 551.31                 Driver Version: 551.31         CUDA Version: 12.4     |\r\n",
    "|-----------------------------------------+------------------------+----------------------+\r\n",
    "| GPU  Name                     TCC/WDDM  | Bus-Id          Disp.A | Volatile Uncorr. ECC |\r\n",
    "| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |\r\n",
    "|                                         |                        |               MIG M. |\r\n",
    "|=========================================+========================+======================|\r\n",
    "|   0  NVIDIA GeForce RTX 4080 ...  WDDM  |   00000000:03:00.0 Off |                  N/A |\r\n",
    "| 50%   70C    P2            279W /  320W |   12542MiB /  16376MiB |     92%      Default |\r\n",
    "|                                         |                        |                  N/A |\r\n",
    "+-----------------------------------------+------------------------+----------------------+\r\n",
    "\r\n",
    "+-----------------------------------------------------------------------------------------+\r\n",
    "| Processes:                                                                              |\r\n",
    "|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |\r\n",
    "|        ID   ID                                                               Usage      |\r\n",
    "|=========================================================================================|\r\n",
    "|    0   N/A  N/A      3788    C+G   ...5n1h2txyewy\\ShellExperienceHost.exe      N/A      |\r\n",
    "|    0   N/A  N/A      5172    C+G   ...5n1h2txyewy\\ShellExperienceHost.exe      N/A      |\r\n",
    "|    0   N/A  N/A     15540    C+G   C:\\Windows\\explorer.exe                     N/A      |\r\n",
    "|    0   N/A  N/A     16712    C+G   ....Search_cw5n1h2txyewy\\SearchApp.exe      N/A      |\r\n",
    "|    0   N/A  N/A     17380    C+G   ...CBS_cw5n1h2txyewy\\TextInputHost.exe      N/A      |\r\n",
    "|    0   N/A  N/A     20388    C+G   ....Search_cw5n1h2txyewy\\SearchApp.exe      N/A      |\r\n",
    "|    0   N/A  N/A     21164    C+G   C:\\Windows\\explorer.exe                     N/A      |\r\n",
    "|    0   N/A  N/A     29408    C+G   ...\\Docker\\frontend\\Docker Desktop.exe      N/A      |\r\n",
    "|    0   N/A  N/A     30376    C+G   D:\\Apifox\\Apifox.exe                        N/A      |\r\n",
    "|    0   N/A  N/A     30992    C+G   C:\\Windows\\explorer.exe                     N/A      |\r\n",
    "|    0   N/A  N/A     44760    C+G   ...CBS_cw5n1h2txyewy\\TextInputHost.exe      N/A      |\r\n",
    "|    0   N/A  N/A     50520    C+G   ...CBS_cw5n1h2txyewy\\TextInputHost.exe      N/A      |\r\n",
    "|    0   N/A  N/A     51448    C+G   ...oogle\\Chrome\\Application\\chrome.exe      N/A      |\r\n",
    "|    0   N/A  N/A     52148    C+G   ...5n1h2txyewy\\ShellExperienceHost.exe      N/A      |\r\n",
    "|    0   N/A  N/A     56100    C+G   ...oogle\\Chrome\\Application\\chrome.exe      N/A      |\r\n",
    "+-----------------------------------------------------------------------------------------+\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "accfe921-471d-481a-96da-c491cdebad0c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='121875' max='121875' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [121875/121875 11:44:08, Epoch 3/3]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Epoch</th>\n",
       "      <th>Training Loss</th>\n",
       "      <th>Validation Loss</th>\n",
       "      <th>Accuracy</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>0.832800</td>\n",
       "      <td>0.764211</td>\n",
       "      <td>0.664760</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>0.650500</td>\n",
       "      <td>0.728333</td>\n",
       "      <td>0.688000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>0.598100</td>\n",
       "      <td>0.736593</td>\n",
       "      <td>0.690520</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "TrainOutput(global_step=121875, training_loss=0.7090755784098307, metrics={'train_runtime': 42249.2956, 'train_samples_per_second': 46.155, 'train_steps_per_second': 2.885, 'total_flos': 5.130803778048e+17, 'train_loss': 0.7090755784098307, 'epoch': 3.0})"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "trainer.train()"
   ]
  },
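  {
   "cell_type": "markdown",
   "id": "3f1c9a52-7b1e-4d0a-9c55-2e8a6d4b1f07",
   "metadata": {},
   "source": [
    "The step count in the progress bar above can be sanity-checked with simple arithmetic. The sketch below assumes single-device training with no gradient accumulation (the `TrainingArguments` shown earlier do not set it) and a training set of 650,000 examples, which is the size of the full YelpReviewFull train split. `total_optimizer_steps` is a hypothetical helper written for this illustration, not part of Transformers.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "\n",
    "def total_optimizer_steps(num_examples, batch_size, num_epochs):\n",
    "    # One optimizer step per batch; the last partial batch is kept,\n",
    "    # hence the ceiling division.\n",
    "    steps_per_epoch = math.ceil(num_examples / batch_size)\n",
    "    return steps_per_epoch * num_epochs\n",
    "\n",
    "\n",
    "# Assumption: 650,000 training examples (full YelpReviewFull train split)\n",
    "steps = total_optimizer_steps(num_examples=650_000, batch_size=16, num_epochs=3)\n",
    "print(steps)  # 121875\n",
    "```\n",
    "\n",
    "With a batch size of 16 and 3 epochs this gives 121,875 steps, matching `global_step` in the `TrainOutput` above, which suggests this run covered the full training split rather than a small subset."
   ]
  },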
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "6d581099-37a4-4470-b051-1ada38554089",
   "metadata": {},
   "outputs": [],
   "source": [
    "small_test_dataset = tokenized_datasets[\"test\"].shuffle(seed=64).select(range(100))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "ffb47eab-1370-491e-8a84-6d5347a350b2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'eval_loss': 0.8981252908706665,\n",
       " 'eval_accuracy': 0.62,\n",
       " 'eval_runtime': 0.9782,\n",
       " 'eval_samples_per_second': 102.233,\n",
       " 'eval_steps_per_second': 13.29,\n",
       " 'epoch': 3.0}"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "trainer.evaluate(small_test_dataset)"
   ]
  },
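  {
   "cell_type": "markdown",
   "id": "8d2b7e64-1c3f-4a9e-b0d7-5f6a2c9e4b18",
   "metadata": {},
   "source": [
    "The throughput fields in the evaluation output above can be cross-checked against the 100 selected test samples. The sketch below multiplies each rate by the runtime. The batch size of 8 is an assumption (the default `per_device_eval_batch_size`, on a single device), since the `TrainingArguments` above do not set it.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "# Values copied from the evaluation output above\n",
    "eval_runtime = 0.9782\n",
    "eval_samples_per_second = 102.233\n",
    "eval_steps_per_second = 13.29\n",
    "\n",
    "# Rate multiplied by runtime recovers how many samples and batches were processed\n",
    "num_samples = round(eval_samples_per_second * eval_runtime)\n",
    "num_batches = round(eval_steps_per_second * eval_runtime)\n",
    "\n",
    "# Assumption: single device, default per_device_eval_batch_size of 8\n",
    "expected_batches = math.ceil(num_samples / 8)\n",
    "print(num_samples, num_batches, expected_batches)\n",
    "```\n",
    "\n",
    "The recovered sample count is 100 and both batch counts agree, so an `eval_accuracy` of 0.62 corresponds to 62 correctly classified reviews. An estimate from only 100 samples is coarse, which is worth keeping in mind when comparing it with the accuracy reported during training."
   ]
  },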
  {
   "cell_type": "markdown",
   "id": "27a55686-7c43-4ab8-a5cd-0e77f14c7c52",
   "metadata": {},
   "source": [
    "### Saving the model and training state\n",
    "\n",
    "- Use `trainer.save_model` to save the model; it can later be reloaded with `from_pretrained()`\n",
    "- Use `trainer.save_state` to save the training state"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "ad0cbc14-9ef7-450f-a1a3-4f92b6486f41",
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.save_model(model_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "f6e30510-0536-49d4-8e1b-43fc25272bde",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "badf5868-2847-439d-a73e-42d1cca67b5e",
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.save_state()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8c9441ad-f65a-42b7-9016-4809c78285e9",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "fd92e35d-fed7-4ff2-aa84-27b5e29b917a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# trainer.model.save_pretrained(\"./\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "61828934-01da-4fc3-9e75-8d754c25dfbc",
   "metadata": {},
   "source": [
    "## Homework: train on the full YelpReviewFull dataset and see how high the accuracy (Acc) can go"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6ee2580a-7a5a-46ae-a28b-b41e9e838eb1",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
